# Author: Mark Richardson
# Purpose: Diagnose convergence of Bayesian estimates of agency performance

## ----setup, include=FALSE-------------------------------------------------------------------------------------
# Default options applied to every chunk of the knitted report: do not echo
# code, and suppress warnings and messages in the output
knitr::opts_chunk$set(
  echo = FALSE,
  warning = FALSE,
  message = FALSE
)


## -------------------------------------------------------------------------------------------------------------

# Proofed 11/05/2023

# Load packages
library(dplyr)
library(rstan)
library(bayesplot)
library(ggplot2)

# Restore the objects saved in the .RData file into the current environment
# (presumably the fitted perf_* models used further down -- confirm against
# the script that wrote this file)
load(file = "02_performance_fitted_models.RData")



## ---- diagnostic_function-------------------------------------------------------------------------------------

# Compute convergence diagnostics (R-hat, bulk ESS, tail ESS) for the
# parameters of a fitted Stan model.
#
# Arguments:
#   stan_object: fitted model. It is passed to summary() and as.array(), so it
#                is presumably an rstan stanfit object (confirm against the
#                callers).
#   pars:        optional character vector used to subset the parameters that
#                are diagnosed. NULL (the default) or a zero-length vector
#                means every parameter reported by summary().
#
# Returns:
#   A list with elements r_hat, ess_bulk and ess_tail. Each element is a
#   numeric vector with one entry per parameter, named by the row names of
#   the summary table.
stanDiag <- function(stan_object, pars = NULL) {

  # The summary table is only used for its row names, which identify the
  # individual parameters to diagnose
  if (length(pars) == 0) {
    p_smy <- summary(stan_object)$summary
  } else {
    p_smy <- summary(stan_object, pars = pars)$summary
  }

  # Get vector of parameters
  p <- rownames(p_smy)

  # Get array of draws for all parameters (not subset by pars); the third
  # dimension is indexed by parameter name below
  array_of_draws <- as.array(stan_object)

  # Apply a single diagnostic function to the draws of each parameter.
  # vapply() preallocates and type-checks the result and, unlike a loop over
  # 1:length(p), simply returns an empty vector when there are no parameters.
  diag_by_par <- function(diag_fun) {
    vals <- vapply(
      p,
      function(par) diag_fun(array_of_draws[, , par]),
      numeric(1),
      USE.NAMES = FALSE
    )
    names(vals) <- p
    vals
  }

  list(
    r_hat = diag_by_par(rstan::Rhat),
    ess_bulk = diag_by_par(rstan::ess_bulk),
    ess_tail = diag_by_par(rstan::ess_tail)
  )
}



## ---- model_diag_lists----------------------------------------------------------------------------------------

# Convergence diagnostics (R-hat, bulk ESS, tail ESS) for each fitted model.
# The perf_* objects are not created in this script; presumably they come from
# the .RData file loaded at the top.
diag_inf_sd_hier <- stanDiag(stan_object = perf_inf_sd_hier)
diag_naive_hier <- stanDiag(stan_object = perf_naive_sd_hier)
diag_inf_sd <- stanDiag(stan_object = perf_inf_sd)
diag_naive <- stanDiag(stan_object = perf_naive)



## -------------------------------------------------------------------------------------------------------------
# Histogram of the R-hat values computed by stanDiag() for perf_inf_sd_hier
mcmc_rhat_hist(rhat = diag_inf_sd_hier[["r_hat"]])


## -------------------------------------------------------------------------------------------------------------
# Histogram of the effective-sample-size ratios for perf_inf_sd_hier
mcmc_neff_hist(ratio = neff_ratio(perf_inf_sd_hier))


## -------------------------------------------------------------------------------------------------------------

# Histogram of bulk ESS across parameters, with a dashed reference line at 400
# (presumably the minimum ESS considered adequate -- confirm)
ggplot(
  data = tibble(ess_bulk = diag_inf_sd_hier[["ess_bulk"]]),
  mapping = aes(x = ess_bulk)
) +
  geom_histogram(
    binwidth = 1000,
    color = "black",
    fill = "dodgerblue",
    alpha = 0.8
  ) +
  geom_vline(xintercept = 400, linetype = "dashed") +
  theme_bw()



## -------------------------------------------------------------------------------------------------------------

# Histogram of tail ESS across parameters, with the same reference line at 400
ggplot(
  data = tibble(ess_tail = diag_inf_sd_hier[["ess_tail"]]),
  mapping = aes(x = ess_tail)
) +
  geom_histogram(
    binwidth = 1000,
    color = "black",
    fill = "dodgerblue",
    alpha = 0.8
  ) +
  geom_vline(xintercept = 400, linetype = "dashed") +
  theme_bw()



## -------------------------------------------------------------------------------------------------------------
# Monte Carlo standard error diagnostic plot for perf_inf_sd_hier
stan_mcse(object = perf_inf_sd_hier)


## -------------------------------------------------------------------------------------------------------------
# Histogram of the R-hat values computed by stanDiag() for perf_naive_sd_hier
mcmc_rhat_hist(rhat = diag_naive_hier[["r_hat"]])


## -------------------------------------------------------------------------------------------------------------
# Histogram of the effective-sample-size ratios for perf_naive_sd_hier
mcmc_neff_hist(ratio = neff_ratio(perf_naive_sd_hier))


## -------------------------------------------------------------------------------------------------------------

# Histogram of bulk ESS across parameters, with a dashed reference line at 400
# (presumably the minimum ESS considered adequate -- confirm)
ggplot(
  data = tibble(ess_bulk = diag_naive_hier[["ess_bulk"]]),
  mapping = aes(x = ess_bulk)
) +
  geom_histogram(
    binwidth = 1000,
    color = "black",
    fill = "dodgerblue",
    alpha = 0.8
  ) +
  geom_vline(xintercept = 400, linetype = "dashed") +
  theme_bw()



## -------------------------------------------------------------------------------------------------------------

# Histogram of tail ESS across parameters, with the same reference line at 400
ggplot(
  data = tibble(ess_tail = diag_naive_hier[["ess_tail"]]),
  mapping = aes(x = ess_tail)
) +
  geom_histogram(
    binwidth = 1000,
    color = "black",
    fill = "dodgerblue",
    alpha = 0.8
  ) +
  geom_vline(xintercept = 400, linetype = "dashed") +
  theme_bw()



## -------------------------------------------------------------------------------------------------------------
# Monte Carlo standard error diagnostic plot for perf_naive_sd_hier
stan_mcse(object = perf_naive_sd_hier)


## -------------------------------------------------------------------------------------------------------------
# Histogram of the R-hat values computed by stanDiag() for perf_inf_sd
mcmc_rhat_hist(rhat = diag_inf_sd[["r_hat"]])


## -------------------------------------------------------------------------------------------------------------
# Histogram of the effective-sample-size ratios for perf_inf_sd
mcmc_neff_hist(ratio = neff_ratio(perf_inf_sd))


## -------------------------------------------------------------------------------------------------------------

# Histogram of bulk ESS across parameters, with a dashed reference line at 400
# (presumably the minimum ESS considered adequate -- confirm)
ggplot(
  data = tibble(ess_bulk = diag_inf_sd[["ess_bulk"]]),
  mapping = aes(x = ess_bulk)
) +
  geom_histogram(
    binwidth = 1000,
    color = "black",
    fill = "dodgerblue",
    alpha = 0.8
  ) +
  geom_vline(xintercept = 400, linetype = "dashed") +
  theme_bw()



## -------------------------------------------------------------------------------------------------------------

# Histogram of tail ESS across parameters, with the same reference line at 400
ggplot(
  data = tibble(ess_tail = diag_inf_sd[["ess_tail"]]),
  mapping = aes(x = ess_tail)
) +
  geom_histogram(
    binwidth = 1000,
    color = "black",
    fill = "dodgerblue",
    alpha = 0.8
  ) +
  geom_vline(xintercept = 400, linetype = "dashed") +
  theme_bw()



## -------------------------------------------------------------------------------------------------------------
# Monte Carlo standard error diagnostic plot for perf_inf_sd
stan_mcse(object = perf_inf_sd)


## -------------------------------------------------------------------------------------------------------------
# Histogram of the R-hat values computed by stanDiag() for perf_naive
mcmc_rhat_hist(rhat = diag_naive[["r_hat"]])


## -------------------------------------------------------------------------------------------------------------
# Histogram of the effective-sample-size ratios for perf_naive
mcmc_neff_hist(ratio = neff_ratio(perf_naive))


## -------------------------------------------------------------------------------------------------------------

# Histogram of bulk ESS across parameters, with a dashed reference line at 400
# (presumably the minimum ESS considered adequate -- confirm)
ggplot(
  data = tibble(ess_bulk = diag_naive[["ess_bulk"]]),
  mapping = aes(x = ess_bulk)
) +
  geom_histogram(
    binwidth = 1000,
    color = "black",
    fill = "dodgerblue",
    alpha = 0.8
  ) +
  geom_vline(xintercept = 400, linetype = "dashed") +
  theme_bw()



## -------------------------------------------------------------------------------------------------------------

# Histogram of tail ESS across parameters, with the same reference line at 400
ggplot(
  data = tibble(ess_tail = diag_naive[["ess_tail"]]),
  mapping = aes(x = ess_tail)
) +
  geom_histogram(
    binwidth = 1000,
    color = "black",
    fill = "dodgerblue",
    alpha = 0.8
  ) +
  geom_vline(xintercept = 400, linetype = "dashed") +
  theme_bw()



## -------------------------------------------------------------------------------------------------------------
# Monte Carlo standard error diagnostic plot for perf_naive
stan_mcse(object = perf_naive)

